{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to deploy Large Language Models (LLMs) to Amazon SageMaker using new Hugging Face LLM DLC\n",
    "\n",
    "This is an example on how to deploy the open-source LLMs, like [BLOOM](bigscience/bloom) to Amazon SageMaker for inference using the new Hugging Face LLM Inference Container. We will deploy the 12B [Open Assistant Model](OpenAssistant/pythia-12b-sft-v8-7k-steps) an open-source Chat LLM trained by the Open Assistant initiative.\n",
    "\n",
    "The example covers:\n",
    "1. [Setup development environment](#1-setup-development-environment)\n",
    "2. [Retrieve the new Hugging Face LLM DLC](#2-retrieve-the-new-hugging-face-llm-dlc)\n",
    "3. [Deploy Open Assistant 12B to Amazon SageMaker](#3-deploy-open-assistant-12b-to-amazon-sagemaker)\n",
    "4. [Run inference and chat with our model](#4-run-inference-and-chat-with-our-model)\n",
    "5. [Clean up](#5-clean-up)\n",
    "\n",
    "## What is Hugging Face LLM Inference DLC?\n",
    "\n",
    "Hugging Face LLM DLC is a new purpose-built Inference Container to easily deploy LLMs in a secure and managed environment. The DLC is powered by [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference), an open-source, purpose-built solution for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation using Tensor Parallelism and dynamic batching for the most popular open-source LLMs, including StarCoder, BLOOM, GPT-NeoX, Llama, and T5. \n",
    "Text Generation Inference is already used by customers such as IBM, Grammarly, and the Open-Assistant initiative implements optimization for all supported model architectures, including:\n",
    "* Tensor Parallelism and custom cuda kernels\n",
    "* Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) on the most popular architectures\n",
    "* Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)\n",
    "* [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput\n",
    "* Accelerated weight loading (start-up time) with [safetensors](https://github.com/huggingface/safetensors)\n",
    "* Logits warpers (temperature scaling, topk, repetition penalty ...)\n",
    "* Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n",
    "* Stop sequences, Log probabilities\n",
    "* Token streaming using Server-Sent Events (SSE)\n",
    "\n",
    "Officially supported model architectures are currently: \n",
    "* [BLOOM](https://huggingface.co/bigscience/bloom) / [BLOOMZ](https://huggingface.co/bigscience/bloomz)\n",
    "* [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)\n",
    "* [Galactica](https://huggingface.co/facebook/galactica-120b)\n",
    "* [SantaCoder](https://huggingface.co/bigcode/santacoder)\n",
    "* [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b) (joi, pythia, lotus, rosey, chip, RedPajama, open assistant)\n",
    "* [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) (T5-11B)\n",
    "* [Llama](https://github.com/facebookresearch/llama) (vicuna, alpaca, koala)\n",
    "* [Starcoder](https://huggingface.co/bigcode/starcoder) / [SantaCoder](https://huggingface.co/bigcode/santacoder)\n",
    "* [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b) / [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)\n",
    "\n",
    "With the new Hugging Face LLM Inference DLCs on Amazon SageMaker, AWS customers can benefit from the same technologies that power highly concurrent, low latency LLM experiences like [HuggingChat](https://hf.co/chat), [OpenAssistant](https://open-assistant.io/), and Inference API for LLM models on the Hugging Face Hub. \n",
    "\n",
    "Lets get started!\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup development environment\n",
    "\n",
    "We are going to use the `sagemaker` python SDK to deploy BLOOM to Amazon SageMaker. We need to make sure to have an AWS account configured and the `sagemaker` python SDK installed. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"sagemaker==2.163.0\" --upgrade --quiet\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are going to use Sagemaker in a local environment. You need access to an IAM Role with the required permissions for Sagemaker. You can find [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) more about it.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Couldn't call 'get_role' to get Role ARN from role name philippschmid to get Role path.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sagemaker role arn: arn:aws:iam::558105141721:role/sagemaker_execution_role\n",
      "sagemaker session region: us-east-1\n"
     ]
    }
   ],
   "source": [
    "import sagemaker\n",
    "import boto3\n",
    "sess = sagemaker.Session()\n",
    "# sagemaker session bucket -> used for uploading data, models and logs\n",
    "# sagemaker will automatically create this bucket if it not exists\n",
    "sagemaker_session_bucket=None\n",
    "if sagemaker_session_bucket is None and sess is not None:\n",
    "    # set to default bucket if a bucket name is not given\n",
    "    sagemaker_session_bucket = sess.default_bucket()\n",
    "\n",
    "try:\n",
    "    role = sagemaker.get_execution_role()\n",
    "except ValueError:\n",
    "    iam = boto3.client('iam')\n",
    "    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
    "\n",
    "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
    "\n",
    "print(f\"sagemaker role arn: {role}\")\n",
    "print(f\"sagemaker session region: {sess.boto_region_name}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Retrieve the new Hugging Face LLM DLC\n",
    "\n",
    "Compared to deploying regular Hugging Face models we first need to retrieve the container uri and provide it to our `HuggingFaceModel` model class with a `image_uri` pointing to the image. To retrieve the new Hugging Face LLM DLC in Amazon SageMaker, we can use the `get_huggingface_llm_image_uri` method provided by the `sagemaker` SDK. This method allows us to retrieve the URI for the desired Hugging Face LLM DLC based on the specified `backend`, `session`, `region`, and `version`. You can find the available versions [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#huggingface-text-generation-inference-containers)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llm image uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.0.0-tgi0.6.0-gpu-py39-cu118-ubuntu20.04\n"
     ]
    }
   ],
   "source": [
    "from sagemaker.huggingface import get_huggingface_llm_image_uri\n",
    "\n",
    "# retrieve the llm image uri\n",
    "llm_image = get_huggingface_llm_image_uri(\n",
    "  \"huggingface\",\n",
    "  version=\"0.8.2\"\n",
    ")\n",
    "\n",
    "# print ecr image uri\n",
    "print(f\"llm image uri: {llm_image}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Deploy Deploy Open Assistant 12B to Amazon SageMaker\n",
    "\n",
    "To deploy [Open Assistant Model](OpenAssistant/pythia-12b-sft-v8-7k-steps) to Amazon SageMaker we create a `HuggingFaceModel` model class and define our endpoint configuration including the `hf_model_id`, `instance_type` etc. We will use a `g5.12xlarge` instance type, which has 4 NVIDIA A10G GPUs and 96GB of GPU memory.\n",
    "\n",
    "_Note: We could also optimize the deployment for cost and use `g5.2xlarge` instance type and enable int-8 quantization._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from sagemaker.huggingface import HuggingFaceModel\n",
    "\n",
    "# sagemaker config\n",
    "instance_type = \"ml.g5.12xlarge\"\n",
    "number_of_gpu = 4\n",
    "health_check_timeout = 300\n",
    "\n",
    "# Define Model and Endpoint configuration parameter\n",
    "config = {\n",
    "  'HF_MODEL_ID': \"OpenAssistant/pythia-12b-sft-v8-7k-steps\", # model_id from hf.co/models\n",
    "  'SM_NUM_GPUS': json.dumps(number_of_gpu), # Number of GPU used per replica\n",
    "  'MAX_INPUT_LENGTH': json.dumps(1024),  # Max length of input text\n",
    "  'MAX_TOTAL_TOKENS': json.dumps(2048),  # Max length of the generation (including input text)\n",
    "  # 'HF_MODEL_QUANTIZE': \"bitsandbytes\", # comment in to quantize\n",
    "}\n",
    "\n",
    "# create HuggingFaceModel with the image uri\n",
    "llm_model = HuggingFaceModel(\n",
    "  role=role,\n",
    "  image_uri=llm_image,\n",
    "  env=config\n",
    ")"
   ]
  },
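  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the cost-optimized variant mentioned in the note above, the configuration could be adapted as shown here. It assumes a single-GPU `ml.g5.2xlarge` instance and enables the `HF_MODEL_QUANTIZE` option that is commented out in the configuration above. The names `quantized_instance_type` and `quantized_config` are illustrative only, and whether the quantized model fits comfortably in GPU memory with these token limits has not been verified here.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# assumption: single-GPU instance type for the cost-optimized variant\n",
    "quantized_instance_type = 'ml.g5.2xlarge'\n",
    "\n",
    "# same settings as the configuration above, but with one GPU and quantization enabled\n",
    "quantized_config = {\n",
    "  'HF_MODEL_ID': 'OpenAssistant/pythia-12b-sft-v8-7k-steps',  # model_id from hf.co/models\n",
    "  'SM_NUM_GPUS': json.dumps(1),  # a single GPU, so no sharding across devices\n",
    "  'MAX_INPUT_LENGTH': json.dumps(1024),  # Max length of input text\n",
    "  'MAX_TOTAL_TOKENS': json.dumps(2048),  # Max length of the generation (including input text)\n",
    "  'HF_MODEL_QUANTIZE': 'bitsandbytes',  # int-8 quantization with bitsandbytes\n",
    "}\n",
    "```\n",
    "\n",
    "To try it, pass `env=quantized_config` when creating the `HuggingFaceModel` and `instance_type=quantized_instance_type` when calling `deploy`."
   ]
  },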
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After we have created the `HuggingFaceModel` we can deploy it to Amazon SageMaker using the `deploy` method. We will deploy the model with the `ml.g5.12xlarge` instance type. TGI will automatically distribute and shard the model across all GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------!"
     ]
    }
   ],
   "source": [
    "# Deploy model to an endpoint\n",
    "# https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy\n",
    "llm = llm_model.deploy(\n",
    "  initial_instance_count=1,\n",
    "  instance_type=instance_type,\n",
    "  # volume_size=400, # If using an instance with local SSD storage, volume_size must be None, e.g. p4 but not p3\n",
    "  container_startup_health_check_timeout=health_check_timeout, # 10 minutes to be able to load the model\n",
    ")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SageMaker will now create our endpoint and deploy the model to it. This can takes a 10-15 minutes. "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Run inference and chat with our model\n",
    "\n",
    "After our endpoint is deployed we can run inference on it. We will use the `predict` method from the `predictor` to run inference on our endpoint. We can inference with different parameters to impact the generation. Parameters can be defined as in the `parameters` attribute of the payload. As of today the TGI supports the following parameters:\n",
    "* `temperature`: Controls randomness in the model. Lower values will make the model more deterministic and higher values will make the model more random. Default value is 1.0.\n",
    "* `max_new_tokens`: The maximum number of tokens to generate. Default value is 20, max value is 512.\n",
    "* `repetition_penalty`: Controls the likelihood of repetition, defaults to `null`.\n",
    "* `seed`: The seed to use for random generation, default is `null`.\n",
    "* `stop`: A list of tokens to stop the generation. The generation will stop when one of the tokens is generated.\n",
    "* `top_k`: The number of highest probability vocabulary tokens to keep for top-k-filtering. Default value is `null`, which disables top-k-filtering.\n",
    "* `top_p`: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling, default to `null`\n",
    "* `do_sample`: Whether or not to use sampling ; use greedy decoding otherwise. Default value is `false`.\n",
    "* `best_of`: Generate best_of sequences and return the one if the highest token logprobs, default to `null`.\n",
    "* `details`: Whether or not to return details about the generation. Default value is `false`.\n",
    "* `return_full_text`: Whether or not to return the full text or only the generated part. Default value is `false`.\n",
    "* `truncate`: Whether or not to truncate the input to the maximum length of the model. Default value is `true`.\n",
    "* `typical_p`: The typical probability of a token. Default value is `null`.\n",
    "* `watermark`: The watermark to use for the generation. Default value is `false`.\n",
    "\n",
    "You can find the open api specification of the TGI in the [swagger documentation](https://huggingface.github.io/text-generation-inference/)\n",
    "\n",
    "The `OpenAssistant/pythia-12b-sft-v8-7k-steps` is a conversational chat model meaning we can chat with it using the following prompt:\n",
    "  \n",
    "```\n",
    "<|prompter|>[Instruction]<|endoftext|>\n",
    "<|assistant|>\n",
    "```\n",
    "\n",
    "lets give it a first try and ask about some cool ideas to do in the summer:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|prompter|>What are some cool ideas to do in the summer?<|endoftext|><|assistant|>There are many fun and exciting things you can do in the summer. Here are some ideas:\n",
      "\n"
     ]
    }
   ],
   "source": [
    "chat = llm.predict({\n",
    "\t\"inputs\": \"\"\"<|prompter|>What are some cool ideas to do in the summer?<|endoftext|><|assistant|>\"\"\"\n",
    "})\n",
    "\n",
    "print(chat[0][\"generated_text\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will run inference with different parameters to impact the generation. Parameters can be defined as in the `parameters` attribute of the payload. This can be used to have the model stop the generation after the turn of the `bot`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|prompter|>How can i stay more active during winter? Give me 3 tips.<|endoftext|><|assistant|>1.  Get outside and go for a walk in the snow!\n",
      "2.  Go sledding!\n",
      "3.  Go skiing!\n"
     ]
    }
   ],
   "source": [
    "# define payload\n",
    "prompt=\"\"\"<|prompter|>How can i stay more active during winter? Give me 3 tips.<|endoftext|><|assistant|>\"\"\"\n",
    "\n",
    "# hyperparameters for llm\n",
    "payload = {\n",
    "  \"inputs\": prompt,\n",
    "  \"parameters\": {\n",
    "    \"do_sample\": True,\n",
    "    \"top_p\": 0.7,\n",
    "    \"temperature\": 0.7,\n",
    "    \"top_k\": 50,\n",
    "    \"max_new_tokens\": 256,\n",
    "    \"repetition_penalty\": 1.03,\n",
    "    \"stop\": [\"<|endoftext|>\"]\n",
    "  }\n",
    "}\n",
    "\n",
    "# send request to endpoint\n",
    "response = llm.predict(payload)\n",
    "\n",
    "# print(response[0][\"generated_text\"][:-len(\"<human>:\")])\n",
    "print(response[0][\"generated_text\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nice, now lets build a quick gradio application to chat with it. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install gradio  --upgrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7862\n",
      "Running on public URL: https://48c4c5606be32ee0a0.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://48c4c5606be32ee0a0.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio as gr\n",
    "\n",
    "# hyperparameters for llm\n",
    "parameters = {\n",
    "    \"do_sample\": True,\n",
    "    \"top_p\": 0.7,\n",
    "    \"temperature\": 0.7,\n",
    "    \"top_k\": 50,\n",
    "    \"max_new_tokens\": 256,\n",
    "    \"repetition_penalty\": 1.03,\n",
    "    \"stop\": [\"<|endoftext|>\"]\n",
    "  }\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"## Chat with Amazon SageMaker\")\n",
    "    with gr.Column():\n",
    "        chatbot = gr.Chatbot()\n",
    "        with gr.Row():\n",
    "            with gr.Column():\n",
    "                message = gr.Textbox(label=\"Chat Message Box\", placeholder=\"Chat Message Box\", show_label=False)\n",
    "            with gr.Column():\n",
    "                with gr.Row():\n",
    "                    submit = gr.Button(\"Submit\")\n",
    "                    clear = gr.Button(\"Clear\")\n",
    "                    \n",
    "    def respond(message, chat_history):\n",
    "        # convert chat history to prompt\n",
    "        converted_chat_history = \"\"\n",
    "        if len(chat_history) > 0:\n",
    "          for c in chat_history:\n",
    "            converted_chat_history += f\"<|prompter|>{c[0]}<|endoftext|><|assistant|>{c[1]}<|endoftext|>\"\n",
    "        prompt = f\"{converted_chat_history}<|prompter|>{message}<|endoftext|><|assistant|>\"\n",
    "        \n",
    "        # send request to endpoint\n",
    "        llm_response = llm.predict({\"inputs\": prompt, \"parameters\": parameters})        \n",
    "        \n",
    "        # remove prompt from response\n",
    "        parsed_response = llm_response[0][\"generated_text\"][len(prompt):]        \n",
    "        chat_history.append((message, parsed_response))\n",
    "        return \"\", chat_history\n",
    "\n",
    "    submit.click(respond, [message, chatbot], [message, chatbot], queue=False)\n",
    "    clear.click(lambda: None, None, chatbot, queue=False)\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![example](assets/gradio.png)\n",
    "\n",
    "Awesome! 🚀 We have successfully deployed Open Assistant Model to Amazon SageMaker and run inference on it. Additionally, we have built a quick gradio application to chat with our model. \n",
    "\n",
    "Now, its time for you to try it out yourself and build Generation AI applications with the new Hugging Face LLM DLC on Amazon SageMaker."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Clean up\n",
    "\n",
    "To clean up, we can delete the model and endpoint.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm.delete_model()\n",
    "llm.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5fcf248a74081676ead7e77f54b2c239ba2921b952f7cbcdbbe5427323165924"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
